{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8b99e4ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.models import ModelXtoCResNet  # noqa: E402\n",
    "from src.preprocessing.RIVAL10 import preprocessing_rival10  # noqa: E402\n",
    "from src.utils import *  # noqa: E402, F403"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9333beb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 26384 unique images.\n",
      "Found 18 unique concepts.\n",
      "Generated one-hot training matrix with shape: (21098, 10)\n",
      "Found 21098 images.\n",
      "Processing in 330 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 330/330 [00:45<00:00,  7.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 21098 images.\n",
      "Found 5286 images.\n",
      "Processing in 83 batches of size 64 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 83/83 [00:17<00:00,  4.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 5286 images.\n",
      "Dataset initialized with 21098 pre-sorted items.\n",
      "Dataset initialized with 5286 pre-sorted items.\n",
      "Split train dataset: 16878 training samples, 4220 validation samples\n"
     ]
    }
   ],
   "source": [
    "# Run the RIVAL10 preprocessing pipeline and unpack its four return values.\n",
    "concept_labels, train_loader, val_loader, test_loader = preprocessing_rival10(\n",
    "    training=False,\n",
    "    class_concepts=True,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75f95a35",
   "metadata": {},
   "source": [
    "# Training Implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e0e00db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.config import RIVAL10_CONFIG as config_dict\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "72754a34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model dimensions are read from the shared RIVAL10 configuration dictionary.\n",
    "N_TRIMMED_CONCEPTS, N_CLASSES = (\n",
    "    config_dict[key] for key in ('N_TRIMMED_CONCEPTS', 'N_CLASSES')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0858c66",
   "metadata": {},
   "source": [
    "**Find device to run model on (CPU or GPU).**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d7b1efe4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "# Prefer CUDA, then Apple MPS, and fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device_name = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device_name = \"mps\"\n",
    "else:\n",
    "    device_name = \"cpu\"\n",
    "\n",
    "device = torch.device(device_name)\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdc74960",
   "metadata": {},
   "source": [
    "**Instantiate the model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3793dfc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/pb/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/Users/pb/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.\n",
      "  warnings.warn(msg)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model Instantiated (X -> C)\n"
     ]
    }
   ],
   "source": [
    "# Constructor options for the ResNet-based model, gathered in one place.\n",
    "model_kwargs = dict(\n",
    "    pretrained=True,\n",
    "    freeze=True,\n",
    "    n_concepts=N_TRIMMED_CONCEPTS,\n",
    "    label_mode=True,\n",
    "    n_classes=N_CLASSES,\n",
    ")\n",
    "\n",
    "# Build the network and move it to the selected device.\n",
    "model = ModelXtoCResNet(**model_kwargs).to(device)\n",
    "print(\"Model Instantiated (X -> C)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24b06149",
   "metadata": {},
   "source": [
    "### Loss\n",
    "We use weighted loss.\n",
    "\n",
    "`BCEWithLogitsLoss()` performs 2 steps:\n",
    "1. $\\sigma(x)$\n",
    "    - Applies the sigmoid function to the logits to get probabilities.\n",
    "2. $\\text{BCE}(\\sigma(x), y) = -\\left[ y \\cdot \\log(\\sigma(x)) + (1-y) \\cdot \\log(1-\\sigma(x)) \\right]$\n",
    "    - Computes the binary cross-entropy between the output probabilities ($\\sigma(x)$) and the ground truths ($y$).\n",
    "\n",
    "**Note:** the code cell below instantiates `nn.CrossEntropyLoss()` without class weights, not `BCEWithLogitsLoss()`. TODO: reconcile this description with the code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "22440cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy on raw logits; `reduction=\"mean\"` is PyTorch's default, spelled out here.\n",
    "attr_criterion = nn.CrossEntropyLoss(reduction=\"mean\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5b0878f",
   "metadata": {},
   "source": [
    "### Optimiser\n",
    "Use same settings as used in CBM repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "008ebdc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizer and Scheduler Ready\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters\n",
    "lr = 0.01\n",
    "weight_decay = 0.00004  # same as lambda in L2-regularisation\n",
    "scheduler_step = 1000  # decrease the LR every `scheduler_step` epochs\n",
    "\n",
    "# Only parameters that still require gradients are handed to SGD.\n",
    "trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = optim.SGD(\n",
    "    trainable_params,\n",
    "    lr=lr,\n",
    "    momentum=0.9,\n",
    "    weight_decay=weight_decay,\n",
    ")\n",
    "\n",
    "# StepLR multiplies the learning rate by `gamma` once every `step_size` scheduler steps.\n",
    "# NOTE(review): with a step size of 1000 the LR stays constant for short runs.\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=0.1)\n",
    "\n",
    "print(\"Optimizer and Scheduler Ready\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "144d34f8",
   "metadata": {},
   "source": [
    "### Training and Validation Loops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0eb2c7ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-run settings\n",
    "epochs, log_interval = 2, 50\n",
    "\n",
    "# Best validation accuracy seen so far; the training loop compares against it.\n",
    "best_val_acc = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ae2e2ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.training import run_epoch_x_to_y\n",
    "\n",
    "# One training pass and (optionally) one validation pass per epoch.\n",
    "for epoch in range(epochs):\n",
    "    print(f\"--- Epoch {epoch+1}/{epochs} ---\")\n",
    "\n",
    "    # Train\n",
    "    train_loss, train_acc = run_epoch_x_to_y(\n",
    "        model, train_loader, attr_criterion, optimizer,\n",
    "        is_training=True, device=device, verbose=True,\n",
    "    )\n",
    "    print(f'Epoch {epoch+1} Train Summary | Loss: {train_loss:.4f} | Acc: {train_acc:.3f}')\n",
    "\n",
    "    # Validate (skipped when no validation loader is available)\n",
    "    if val_loader:\n",
    "        with torch.no_grad():\n",
    "            # is_training=False is explicit so evaluation never depends on the helper's default.\n",
    "            val_loss, val_acc = run_epoch_x_to_y(\n",
    "                model, val_loader, attr_criterion, optimizer,\n",
    "                is_training=False, device=device, verbose=True,\n",
    "            )\n",
    "\n",
    "        print(f'Epoch {epoch+1} Val Summary   | Loss: {val_loss:.4f} | Acc: {val_acc:.3f}')\n",
    "\n",
    "        # Track the best validation accuracy seen so far.\n",
    "        if val_acc > best_val_acc:\n",
    "            # Checkpointing is disabled below, so the message must not claim a save happened.\n",
    "            print(f\"Validation accuracy improved ({best_val_acc:.3f} -> {val_acc:.3f}).\")\n",
    "            best_val_acc = val_acc\n",
    "            # To persist the best weights, re-enable:\n",
    "            # torch.save(model, 'x_to_c_best_model.pth')\n",
    "            # print(\"Model saved to x_to_c_best_model.pth\")\n",
    "\n",
    "    # Scheduler step (once per epoch)\n",
    "    scheduler.step()\n",
    "    print(f\"Current LR: {optimizer.param_groups[0]['lr']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e51089",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the in-memory model on the test loader, if one is available.\n",
    "if test_loader:\n",
    "    with torch.no_grad():\n",
    "        # is_training=False is explicit so evaluation never depends on the helper's default.\n",
    "        test_loss, test_acc = run_epoch_x_to_y(\n",
    "            model, test_loader, attr_criterion, optimizer,\n",
    "            is_training=False, device=device, verbose=True,\n",
    "        )\n",
    "\n",
    "    # Report inside the guard: test_loss/test_acc only exist when the evaluation ran.\n",
    "    # NOTE(review): no checkpoint is reloaded in this notebook, so this is the model as left by the last epoch.\n",
    "    print(f'Best Model Summary   | Loss: {test_loss:.4f} | Acc: {test_acc:.3f}')\n",
    "else:\n",
    "    print('No test loader available; skipping test evaluation.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30f3352",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab91783",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d652f20",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
